CS236605: Deep Learning on Computational Accelerators

Homework Assignment 3

Faculty of Computer Science, Technion.

Submitted by:

# Name Id email
Student 1 nevo agmon 203116769 nevoagmon@campus.technion.ac.il
Student 2 yuval shildan 205799307 yuvalshildan@campus.technion.ac.il

Introduction

In this assignment we'll learn to generate text with a deep multilayer RNN network based on GRU cells. Then we'll focus our attention on image generation and implement two different generative models: a variational autoencoder and a generative adversarial network.

General Guidelines

  • Please read the getting started page on the course website. It explains how to set up, run, and submit the assignment.
  • This assignment requires running on GPU-enabled hardware. Please read the course servers usage guide. It explains how to use and run your code on the course servers to benefit from training with GPUs.
  • The text and code cells in these notebooks are intended to guide you through the assignment and help you verify your solutions. The notebooks do not need to be edited at all (unless you wish to play around). The only exception is to fill your name(s) in the above cell before submission. Please do not remove sections or change the order of any cells.
  • All your code (and even answers to questions) should be written in the files within the python package corresponding to the assignment number (hw1, hw2, etc). You can of course use any editor or IDE to work on these files.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bb}[1]{\boldsymbol{#1}} $$

Part 1: Sequence Models

In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Text generation with a char-level RNN

Obtaining the corpus

Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.

In [2]:
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')

def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
    pathlib.Path(out_path).mkdir(exist_ok=True)
    out_filename = os.path.join(out_path, os.path.basename(url))
    
    if os.path.isfile(out_filename) and not force:
        print(f'Corpus file {out_filename} exists, skipping download.')
    else:
        print(f'Downloading {url}...')
        with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        print(f'Saved to {out_filename}.')
    return out_filename
    
corpus_path = download_corpus()
Corpus file /home/nevoagmon/.pytorch-datasets/shakespeare.txt exists, skipping download.

Load the text into memory and print a snippet:

In [3]:
with open(corpus_path, 'r') as f:
    corpus = f.read()

print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL

by William Shakespeare

Dramatis Personae

  KING OF FRANCE
  THE DUKE OF FLORENCE
  BERTRAM, Count of Rousillon
  LAFEU, an old lord
  PAROLLES, a follower of Bertram
  TWO FRENCH LORDS, serving with Bertram

  STEWARD, Servant to the Countess of Rousillon
  LAVACHE, a clown and Servant to the Countess of Rousillon
  A PAGE, Servant to the Countess of Rousillon

  COUNTESS OF ROUSILLON, mother to Bertram
  HELENA, a gentlewoman protected by the Countess
  A WIDOW OF FLORENCE.
  DIANA, daughter to the Widow

  VIOLENTA, neighbour and friend to the Widow
  MARIANA, neighbour and friend to the Widow

  Lords, Officers, Soldiers, etc., French and Florentine  

SCENE:
Rousillon; Paris; Florence; Marseilles

ACT I. SCENE 1.
Rousillon. The COUNT'S palace

Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black

  COUNTESS. In delivering my son from me, I bury a second husband.
  BERTRAM. And I in going, madam, weep o'er my father's death anew;
    but I must attend his Majesty's command, to whom I am now in
    ward, evermore in subjection.
  LAFEU. You shall find of the King a husband, madam; you, sir, a
    father. He that so generally is at all times good must of
    

Data Preprocessing

The first thing we'll need is a mapping from each unique character in the corpus to an index that will represent it in our learning process.

TODO: Implement the char_maps() function in the hw3/charnn.py module.
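
A minimal sketch of one possible approach (char_maps_sketch is an illustrative name, not the required API; sorting the unique chars makes the mapping deterministic):

def char_maps_sketch(text: str):
    # Collect the unique chars and assign consecutive indices in sorted order.
    unique_chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(unique_chars)}
    idx_to_char = {i: c for i, c in enumerate(unique_chars)}
    return char_to_idx, idx_to_char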

In [4]:
import hw3.charnn as charnn

char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)

test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'\n': 0, ' ': 1, '!': 2, '"': 3, '$': 4, '&': 5, "'": 6, '(': 7, ')': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21, ':': 22, ';': 23, '<': 24, '?': 25, 'A': 26, 'B': 27, 'C': 28, 'D': 29, 'E': 30, 'F': 31, 'G': 32, 'H': 33, 'I': 34, 'J': 35, 'K': 36, 'L': 37, 'M': 38, 'N': 39, 'O': 40, 'P': 41, 'Q': 42, 'R': 43, 'S': 44, 'T': 45, 'U': 46, 'V': 47, 'W': 48, 'X': 49, 'Y': 50, 'Z': 51, '[': 52, ']': 53, '_': 54, 'a': 55, 'b': 56, 'c': 57, 'd': 58, 'e': 59, 'f': 60, 'g': 61, 'h': 62, 'i': 63, 'j': 64, 'k': 65, 'l': 66, 'm': 67, 'n': 68, 'o': 69, 'p': 70, 'q': 71, 'r': 72, 's': 73, 't': 74, 'u': 75, 'v': 76, 'w': 77, 'x': 78, 'y': 79, 'z': 80, '}': 81, '\ufeff': 82}

It seems we have some strange characters in the corpus that are very rare and are probably due to mistakes. To reduce the length of the tensors we'll later use to represent our chars, it's best to remove them.

TODO: Implement the remove_chars() function in the hw3/charnn.py module.
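
A minimal sketch of one way this could work (illustrative name; assuming the function should also report how many chars were removed, as the code below expects):

def remove_chars_sketch(text: str, chars_to_remove):
    # str.translate deletes chars mapped to None in a single pass over the text.
    clean_text = text.translate({ord(c): None for c in chars_to_remove})
    return clean_text, len(text) - len(clean_text)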

In [5]:
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')

# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 34 chars

The next thing we need is an embedding of the characters. An embedding is a representation of each token from the sequence as a tensor. For a char-level RNN our tokens are chars, so we can use the simplest possible embedding: encoding each char as a one-hot tensor. In other words, each char will be represented as a tensor whose length is the total number of unique chars (V) and which contains all zeros except at the index corresponding to that specific char.

TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.
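
A minimal sketch of the two directions (illustrative names; shapes and dtype follow the test below):

import torch

def chars_to_onehot_sketch(text: str, char_to_idx: dict) -> torch.Tensor:
    # Build an (S, V) tensor of zeros and set a single 1 per row.
    idx = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
    onehot = torch.zeros(len(text), len(char_to_idx), dtype=torch.int8)
    onehot[torch.arange(len(text)), idx] = 1
    return onehot

def onehot_to_chars_sketch(embedding: torch.Tensor, idx_to_char: dict) -> str:
    # Invert the encoding: the argmax of each row is the char's index.
    return ''.join(idx_to_char[i.item()] for i in embedding.argmax(dim=-1))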

In [6]:
# Wrap the actual embedding functions for calling convenience
def embed(text):
    return charnn.chars_to_onehot(text, char_to_idx)

def unembed(embedding):
    return charnn.onehot_to_chars(embedding, idx_to_char)

text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))

test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
   
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], dtype=torch.int8)

Dataset Creation

We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.

We will split our corpus into shorter sequences of length S chars (try to think why; see question below). Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence. For each sample, we'll also need a label. This is simply another sequence, shifted by one char, so that the label of each char is the next char in the corpus.

TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.
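
A sketch of one way to build the samples and labels (reusing the hypothetical one-hot helper from above; the shift-by-one relationship matches the content test below):

def chars_to_labelled_samples_sketch(text, char_to_idx, seq_len, device):
    # Keep only complete sequences; the final char has no following label.
    num_samples = (len(text) - 1) // seq_len
    num_chars = num_samples * seq_len

    onehot = chars_to_onehot_sketch(text[:num_chars], char_to_idx)
    samples = onehot.view(num_samples, seq_len, -1).to(device)

    # Labels are the char indices shifted forward by one position.
    label_idx = [char_to_idx[c] for c in text[1:num_chars + 1]]
    labels = torch.tensor(label_idx, dtype=torch.long).view(num_samples, seq_len).to(device)
    return samples, labels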

In [7]:
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)

# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')

# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))

# Test content
for _ in range(1000):
    # random sample
    i = np.random.randint(num_samples, size=(1,))[0]
    # Compare to corpus
    test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
    # Compare to labels
    sample_text = unembed(samples[i])
    label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
    test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
    
print(f'sample 100 as text:\n{unembed(samples[100])}')
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])
sample 100 as text:
nity, though valiant in the
    defence, yet is weak. Unfold to 

As usual, instead of feeding one sample at a time into our model's forward we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that come from different sequences. Effectively, this allows us to parallelize training by doing matrix-matrix multiplications instead of matrix-vector multiplications during the forward pass.

Let's use the standard PyTorch Dataset/DataLoader combo. Luckily for the dataset we can use a built-in class, TensorDataset to return tuples of (sample, label) from the samples and labels tensors we created above.

In [8]:
import torch.utils.data

# Create DataLoader returning batches of samples.
batch_size = 32

ds_corpus = torch.utils.data.TensorDataset(samples, labels)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, shuffle=False)

Let's see what that gives us:

In [9]:
print(f'num batches: {len(dl_corpus)}')

x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch sample: {x0.shape}')
print(f'shape of a batch label: {y0.shape}')
num batches: 3100
shape of a batch sample: torch.Size([32, 64, 78])
shape of a batch label: torch.Size([32, 64])

Model Implementation

Finally, our data set is ready so we can focus on our model.

What we'll implement here is a multilayer gated recurrent unit (GRU) model, with dropout. This model is a type of RNN which performs similarly to the well-known LSTM model, but is somewhat easier to train because it has fewer parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.

The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.

Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as

$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$

The input to each layer is, $$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \\ \vdots \\ {\vec{x}_S}^{[k]} \end{bmatrix} = \begin{cases} \mat{X} & \mathrm{if} ~k = 1~ \\ \mathrm{dropout}_p \left( \begin{bmatrix} {\vec{h}_1}^{[k-1]} \\ \vdots \\ {\vec{h}_S}^{[k-1]} \end{bmatrix} \right) & \mathrm{if} ~1 < k \leq L+1~ \end{cases}. $$

The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$

and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$

Notes:

  • $t\in[1,S]$ is the timestep, i.e. the current position within the sequence of each sample.
  • $\vec{x}_t^{[k]}$ is the input of layer $k$ at timestep $t$.
  • The outputs of the last layer, $\vec{y}_t^{[L]}$, are the predicted next-char scores for every input char. These are similar to class scores in classification tasks.
  • The hidden states at the last timestep, $\vec{h}_S^{[k]}$, are the final hidden state returned from the model.
  • $\sigma(\cdot)$ is the sigmoid function, i.e. $\sigma(\vec{z}) = 1/(1+e^{-\vec{z}})$ which returns values in $(0,1)$.
  • $\tanh(\cdot)$ is the hyperbolic tangent, i.e. $\tanh(\vec{z}) = (e^{2\vec{z}}-1)/(e^{2\vec{z}}+1)$ which returns values in $(-1,1)$.
  • $\vec{h_t}^{[k]}$ is the hidden state of layer $k$ at time $t$. This can be thought of as the memory of that layer.
  • $\vec{g_t}^{[k]}$ is the candidate hidden state for time $t$.
  • $\vec{z_t}^{[k]}$ is known as the update gate. It combines the previous state with the input to determine how much the current state will be combined with the new candidate state. For example, if $\vec{z_t}^{[k]}=\vec{1}$ then the current input has no effect on the output.
  • $\vec{r_t}^{[k]}$ is known as the reset gate. It combines the previous state with the input to determine how much of the previous state will affect the current state candidate. For example if $\vec{r_t}^{[k]}=\vec{0}$ the previous state has no effect on the current candidate state.

Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore its previous state, completely ignore its input, or use any mixture of the two (since the gate values are continuous, taking values in $(0,1)$).

Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.

TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.

Notes:

  • You'll need to handle input batches now. The math is identical to the above, but all the tensors will have an extra batch dimension as their first dimension.
  • Use the diagram above to help guide your implementation. It will help you visualize what shapes to return where, etc.
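
Before diving into the full class, here's a minimal sketch of the per-timestep update of a single layer, following the equations above (parameter names are illustrative; the real MultilayerGRU must also stack layers, apply dropout between them, loop over timesteps, and project the top layer's outputs with the hy weights):

import torch
import torch.nn as nn

class GRUCellSketch(nn.Module):
    def __init__(self, in_dim, h_dim):
        super().__init__()
        # One Linear pair per gate; only the input projection carries a bias.
        self.Wxz, self.Whz = nn.Linear(in_dim, h_dim), nn.Linear(h_dim, h_dim, bias=False)
        self.Wxr, self.Whr = nn.Linear(in_dim, h_dim), nn.Linear(h_dim, h_dim, bias=False)
        self.Wxg, self.Whg = nn.Linear(in_dim, h_dim), nn.Linear(h_dim, h_dim, bias=False)

    def forward(self, xt, h_prev):
        # xt: (B, in_dim), h_prev: (B, h_dim)
        z = torch.sigmoid(self.Wxz(xt) + self.Whz(h_prev))   # update gate
        r = torch.sigmoid(self.Wxr(xt) + self.Whr(h_prev))   # reset gate
        g = torch.tanh(self.Wxg(xt) + self.Whg(r * h_prev))  # candidate state
        return z * h_prev + (1 - z) * g                      # new hidden state
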
In [10]:
in_dim = vocab_len
h_dim = 256
n_layers = 2
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)

# Test forward pass
y, h = model(x0.to(dtype=torch.float))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')

test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2) 
MultilayerGRU(
  (param_l0_p0): Linear(in_features=78, out_features=1, bias=True)
  (param_l0_p1): Linear(in_features=256, out_features=1, bias=False)
  (param_l0_p2): Linear(in_features=78, out_features=1, bias=True)
  (param_l0_p3): Linear(in_features=256, out_features=1, bias=False)
  (param_l0_p4): Linear(in_features=78, out_features=256, bias=True)
  (param_l0_p5): Linear(in_features=256, out_features=256, bias=False)
  (param_l0_p6): Dropout(p=0)
  (param_l1_p0): Linear(in_features=256, out_features=1, bias=True)
  (param_l1_p1): Linear(in_features=256, out_features=1, bias=False)
  (param_l1_p2): Linear(in_features=256, out_features=1, bias=True)
  (param_l1_p3): Linear(in_features=256, out_features=1, bias=False)
  (param_l1_p4): Linear(in_features=256, out_features=256, bias=True)
  (param_l1_p5): Linear(in_features=256, out_features=256, bias=False)
  (param_l1_p6): Dropout(p=0)
  (weights_hy): Linear(in_features=256, out_features=78, bias=True)
)
y.shape=torch.Size([32, 64, 78])
h.shape=torch.Size([32, 2, 256])

Generating text by sampling

Now that we have a model, we can implement text generation based on it. The idea is simple: At each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability over each of the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t; \vec{h}_t).$$

Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.

The important point however is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it can generate very diffuse (more uniform) distributions if the score values are very similar. When sampling, we would prefer to control the distributions and make them less uniform to increase the chance of sampling the char(s) with the highest scores compared to the others.

To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature, to the softmax function. The class scores are simply divided by $T$ before the softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$

A low $T$ will result in less uniform distributions and vice-versa.

TODO: Implement the hot_softmax() function in the hw3/charnn.py module.
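
A sketch of this function under the above definition (a numerically-stable variant that subtracts the maximal score before exponentiating, which does not change the result):

def hot_softmax_sketch(y, dim=0, temperature=1.0):
    # Divide the scores by T, then apply a standard stable softmax.
    y = y / temperature
    y = y - y.max(dim=dim, keepdim=True)[0]
    exp_y = torch.exp(y)
    return exp_y / exp_y.sum(dim=dim, keepdim=True)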

In [11]:
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))

for t in reversed([0.3, 0.5, 1.0, 100]):
    ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()

uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))

TODO: Implement the generate_from_model() function in the hw3/charnn.py module.
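
A sketch of the sampling loop (assuming the model's forward accepts an optional initial hidden state and returns the final one; the helper names refer to the hypothetical sketches above):

def generate_from_model_sketch(model, start_sequence, n_chars, char_maps, T):
    char_to_idx, idx_to_char = char_maps
    device = next(model.parameters()).device
    out_text = start_sequence
    h = None  # assumed: h=None means a zero initial hidden state

    with torch.no_grad():
        x = chars_to_onehot_sketch(start_sequence, char_to_idx)
        x = x.unsqueeze(0).to(device, dtype=torch.float)  # (1, S, V)
        while len(out_text) < n_chars:
            y, h = model(x, h)  # crucially, keep propagating the hidden state
            probs = hot_softmax_sketch(y[0, -1, :], temperature=T)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            out_text += idx_to_char[next_idx]
            # Feed back only the newly sampled char; h carries the history.
            x = chars_to_onehot_sketch(out_text[-1], char_to_idx)
            x = x.unsqueeze(0).to(device, dtype=torch.float)
    return out_text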

In [12]:
for _ in range(3):
    text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
    print(text)
    test.assertEqual(len(text), 50)
foobarRl9vuOirg-2FJPxUb3vf0[GV0)!L[l)]jrQ;U:)[D7?7
foobarm TF3n0D2..8VfuEp0YQpOqn7,A33:i9AGCH.HcOrOxT
foobarVKZ3NfG dwn424(:HpMW[kim:&P:rO'Je0r
7U2n)Mdd

Training

To train such a model, we'll calculate the loss at each timestep by comparing the predicted char to the actual char from our label. We can use cross-entropy since, per char, this is similar to a classification problem. We'll then sum the losses over the sequence and back-propagate the gradients through time. Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times, so gradients will accumulate in the parameters of the blocks. Luckily, autograd will handle this part for us.
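
For example, since nn.CrossEntropyLoss expects (N, C) scores and (N,) class indices, one way to compute the per-timestep loss over a whole batch is to fold the time dimension into the batch dimension (a sketch; here y is the (B, S, V) model output and labels is the (B, S) label tensor; note that CrossEntropyLoss averages by default, and reduction='sum' would sum instead):

import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
# Flatten (B, S, V) -> (B*S, V) and (B, S) -> (B*S,) before the loss.
loss = loss_fn(y.reshape(-1, y.shape[-1]), labels.reshape(-1))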

As usual, the first step of training will be to try and overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.

For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.

Let's create a tiny dataset to memorize.

In [13]:
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size=1, shuffle=False)

# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
    subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":

TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

Now let's implement the first part of our training code.

TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module. Note: Think about how to correctly handle the hidden state of the model between batches and epochs (for this specific task, i.e. text generation).
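
One possible way to handle the hidden state is sketched below (hypothetical attribute and method names, not the required API): carry the final hidden state from one batch to the next, since the batches are not shuffled, but detach it so gradients don't flow across batch boundaries, and reset it at the start of each epoch.

# Inside RNNTrainer (sketch; self.h_state is assumed to be reset to None
# at the start of each epoch in train_epoch()):
def train_batch_sketch(self, batch):
    x, y = batch
    x = x.to(self.device, dtype=torch.float)
    y = y.to(self.device)

    seq_scores, self.h_state = self.model(x, self.h_state)
    self.h_state = self.h_state.detach()  # truncate BPTT at the batch boundary

    loss = self.loss_fn(seq_scores.reshape(-1, seq_scores.shape[-1]), y.reshape(-1))
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss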

In [14]:
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer

torch.manual_seed(42)

lr = 0.01
num_epochs = 500

in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

for epoch in range(num_epochs):
    epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)
    
    # Every X epochs, we'll generate a sequence starting from the first char in the first sequence
    # to visualize how/if/what the model is learning.
    if epoch == 0 or (epoch+1) % 25 == 0:
        avg_loss = np.mean(epoch_result.losses)
        accuracy = np.mean(epoch_result.accuracy)
        print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
        
        generated_sequence = charnn.generate_from_model(model, subset_text[0],
                                                        seq_len*(subset_end-subset_start),
                                                        (char_to_idx,idx_to_char), T=0.1)
        # Stop if we've successfully memorized the small dataset.
        print(generated_sequence)
        if generated_sequence == subset_text:
            break

# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.819, Accuracy = 18.75%
Tos                                                        t                           o                t               t                     t    o                                                              t                                             

Epoch #25: Avg. loss = 0.230, Accuracy = 92.58%
TRAM. What I would not kiss.
    Faith, yes:
    Faith, yes:
    I would not kiss.
    Faith, yes:
    I would not in haste to horse.
    Faith, yes:
    I would not kiss.
    Faith, yes:
    Faith, yes:
    I would not in haste to horse.
    Faith, yes:
 

Epoch #50: Avg. loss = 0.025, Accuracy = 99.22%
TRAM. What would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not t

Epoch #75: Avg. loss = 0.053, Accuracy = 98.05%
TRAM. What would you what I would you what I would not tell you what I would not tell you what I would you what I would you what I would you what I would you what I would you what I would you what I would you what I would you what I would you what I would 

Epoch #100: Avg. loss = 0.003, Accuracy = 100.00%
TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

OK, so training works - we can memorize a short sequence. Next on the agenda is to split our full dataset into training and test sets of batched sequences.

In [15]:
# Full dataset definition
vocab_len = len(char_to_idx)
seq_len = 64
batch_size = 256
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)

samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)

ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=False, drop_last=True)

ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, drop_last=True)

print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test:  {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')
Train: 348 batches, 5701632 chars
Test:   38 batches,  622592 chars

We'll now train a much larger model on our large dataset. You'll need a GPU for this part.

The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left off.

Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.

In [16]:
# Full training definition
lr = 0.001
num_epochs = 50

in_dim = out_dim = vocab_len
hidden_dim = 512
n_layers = 3
dropout = 0.5
checkpoint_file = 'checkpoints/rnn'
max_batches = 300
early_stopping = 5

model = charnn.MultilayerGRU(in_dim, hidden_dim, out_dim, n_layers, dropout)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

TODO:

  • Implement the fit() method of the Trainer class. You can reuse the implementation from HW2, but make sure to implement early stopping and checkpoints (a sketch of this bookkeeping appears after this list).
  • Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
  • Run the following block to train.
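
As referenced above, here's a sketch of the early-stopping and checkpointing bookkeeping inside fit() (local names are illustrative; the 'model_state' key matches what the checkpoint-loading code below expects):

# Inside Trainer.fit() (sketch):
best_acc, epochs_without_improvement = None, 0
for epoch in range(num_epochs):
    train_res = self.train_epoch(dl_train, max_batches=max_batches)
    test_res = self.test_epoch(dl_test, max_batches=max_batches)

    if best_acc is None or test_res.accuracy > best_acc:
        best_acc, epochs_without_improvement = test_res.accuracy, 0
        if checkpoints is not None:
            # Save the best parameters so training can be stopped and resumed.
            torch.save(dict(model_state=self.model.state_dict()), f'{checkpoints}.pt')
    else:
        epochs_without_improvement += 1
        if early_stopping is not None and epochs_without_improvement >= early_stopping:
            break
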
In [17]:
from cs236605.plot import plot_fit

def post_epoch_fn(epoch, test_res, train_res, verbose):
    # Update learning rate
    scheduler.step(test_res.accuracy)
    # Sample from model to show progress
    if verbose:
        start_seq = "ACT I."
        generated_sequence = charnn.generate_from_model(
            model, start_seq, 100, (char_to_idx,idx_to_char), T=0.5
        )
        print(generated_sequence)

# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    saved_state = torch.load(checkpoint_file_final, map_location=device)
    model.load_state_dict(saved_state['model_state'])
else:
    try:
        # Print pre-training sampling
        print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx,idx_to_char), T=0.5))

        fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=max_batches,
                              post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
                              checkpoints=checkpoint_file, print_every=1)
        
        fig, axes = plot_fit(fit_res)
    except KeyboardInterrupt as e:
        print('\n *** Training interrupted by user')
ACT I.nXg,;ryvqq]
LE3ySiA!FV4iN?CbYB4d0Ul&91ePz PZW:PsC61Af0stFh.tebq:!ZmdpNvxXI7U1r!Fxwv;yo: 7TVBC

*** Loading checkpoint file checkpoints/rnn.pt
--- EPOCH 1/50 ---
train_batch (Avg. Loss 1.805, Accuracy 48.0): 100%|██████████| 348/348 [01:26<00:00,  4.00it/s]
test_batch (Avg. Loss 1.787, Accuracy 46.5): 100%|██████████| 38/38 [00:03<00:00, 11.88it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 1
ACT I.

KING RICHARD II:
And the like with so are their word my sow with the brother the slainted th
--- EPOCH 2/50 ---
train_batch (Avg. Loss 1.782, Accuracy 48.6): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.762, Accuracy 47.2): 100%|██████████| 38/38 [00:03<00:00, 11.83it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 2
ACT I.

KING RICHARD RORI:
But and and gaint,
With thou some so the king man you dest this shall be 
--- EPOCH 3/50 ---
train_batch (Avg. Loss 1.763, Accuracy 49.1): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.744, Accuracy 47.8): 100%|██████████| 38/38 [00:03<00:00, 11.84it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 3
ACT I.

CORIOLANUS:
You have as shall present in our ground the of the son her with a country this s
--- EPOCH 4/50 ---
train_batch (Avg. Loss 1.747, Accuracy 49.6): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.728, Accuracy 48.1): 100%|██████████| 38/38 [00:03<00:00, 11.76it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 4
ACT I.

KING RICHARD II:
Be me heart not of the son the towy see to disposed the to speak a more goo
--- EPOCH 5/50 ---
train_batch (Avg. Loss 1.734, Accuracy 49.9): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.716, Accuracy 48.5): 100%|██████████| 38/38 [00:03<00:00, 11.73it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 5
ACT I.

KING RICHARD II:
We have should so shall strange a part the should be see the prithee her sp
--- EPOCH 6/50 ---
train_batch (Avg. Loss 1.727, Accuracy 50.2): 100%|██████████| 348/348 [01:26<00:00,  4.02it/s]
test_batch (Avg. Loss 1.705, Accuracy 48.9): 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 6
ACT I.

KING RICHARD II:
The say shall not my that to free the she his were the this with the good c
--- EPOCH 7/50 ---
train_batch (Avg. Loss 1.717, Accuracy 50.4): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.699, Accuracy 49.0): 100%|██████████| 38/38 [00:03<00:00, 11.83it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 7
ACT I.

KING RICHARD II:
God the wise in the king the fornow the boint have a fair and my borance sh
--- EPOCH 8/50 ---
train_batch (Avg. Loss 1.708, Accuracy 50.6): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.690, Accuracy 49.3): 100%|██████████| 38/38 [00:03<00:00, 11.68it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 8
ACT I.
  CARIA:
So so thou art the that shall be not and were not his court the good the grace, and 
--- EPOCH 9/50 ---
train_batch (Avg. Loss 1.701, Accuracy 50.8): 100%|██████████| 348/348 [01:26<00:00,  4.00it/s]
test_batch (Avg. Loss 1.681, Accuracy 49.5): 100%|██████████| 38/38 [00:03<00:00, 11.87it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 9
ACT I.

DUKE II:
No man to the both we more flight the crown he let be with be the pertonse her let 
--- EPOCH 10/50 ---
train_batch (Avg. Loss 1.694, Accuracy 51.0): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.676, Accuracy 49.7): 100%|██████████| 38/38 [00:03<00:00, 11.74it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 10
ACT I.

KING RICHARD II:
I will not the tongue, but the prove the hand in the can of the poor the wo
--- EPOCH 11/50 ---
train_batch (Avg. Loss 1.688, Accuracy 51.2): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.669, Accuracy 49.9): 100%|██████████| 38/38 [00:03<00:00, 11.64it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 11
ACT I.

GLOUCESTER:
What was the can for the tood brother the son, behead be the brother to his fath
--- EPOCH 12/50 ---
train_batch (Avg. Loss 1.683, Accuracy 51.3): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.666, Accuracy 50.0): 100%|██████████| 38/38 [00:03<00:00, 11.77it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 12
ACT I.
False the true in the man a will poor the great the for the soul so him the rage her be me he
--- EPOCH 13/50 ---
train_batch (Avg. Loss 1.677, Accuracy 51.5): 100%|██████████| 348/348 [01:26<00:00,  4.02it/s]
test_batch (Avg. Loss 1.662, Accuracy 50.1): 100%|██████████| 38/38 [00:03<00:00, 11.84it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 13
ACT I.
She speak the are so the lord, for the reason the sake.

GLOUCESTER:
But the lady cousin the 
--- EPOCH 14/50 ---
train_batch (Avg. Loss 1.673, Accuracy 51.6): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.658, Accuracy 50.1): 100%|██████████| 38/38 [00:03<00:00, 11.75it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 14
ACT I.
This soul to the see the country the word and the death move of the grown here with a quarner
--- EPOCH 15/50 ---
train_batch (Avg. Loss 1.668, Accuracy 51.7): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.653, Accuracy 50.2): 100%|██████████| 38/38 [00:03<00:00, 11.77it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 15
ACT I.
SCENE II:
My lord, it shall be common and death of the soul the hand my poor his thing the gr
--- EPOCH 16/50 ---
train_batch (Avg. Loss 1.665, Accuracy 51.8): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.651, Accuracy 50.4): 100%|██████████| 38/38 [00:03<00:00, 11.70it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 16
ACT I. SCENE II:
He the court me with is the see she on the presently long of his servence where wor
--- EPOCH 17/50 ---
train_batch (Avg. Loss 1.661, Accuracy 51.9): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.647, Accuracy 50.5): 100%|██████████| 38/38 [00:03<00:00, 11.69it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 17
ACT I. SCENE II:
Who cannot with me live me here is shall the borne the heart the such them would no
--- EPOCH 18/50 ---
train_batch (Avg. Loss 1.657, Accuracy 52.0): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.646, Accuracy 50.6): 100%|██████████| 38/38 [00:03<00:00, 11.70it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 18
ACT I. SIRITES:
This this thing the tome since of thy spart of the more the more beswer from the tho
--- EPOCH 19/50 ---
train_batch (Avg. Loss 1.655, Accuracy 52.1): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.642, Accuracy 50.6): 100%|██████████| 38/38 [00:03<00:00, 11.77it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 19
ACT I. SCENE II:
Where was the fear the restenger the poor provise, the man are the reason confess t
--- EPOCH 20/50 ---
train_batch (Avg. Loss 1.652, Accuracy 52.2): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.640, Accuracy 50.7): 100%|██████████| 38/38 [00:03<00:00, 11.79it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 20
ACT I. SCENE
III:
Which in her forth a that what the king that with the sound that look the life and
--- EPOCH 21/50 ---
train_batch (Avg. Loss 1.648, Accuracy 52.3): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.637, Accuracy 50.8): 100%|██████████| 38/38 [00:03<00:00, 11.66it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 21
ACT I. SCENE II:
Good to me, my father's part of Rome, sir, thou shall be state the our consent.

SI
--- EPOCH 22/50 ---
train_batch (Avg. Loss 1.646, Accuracy 52.3): 100%|██████████| 348/348 [01:26<00:00,  4.00it/s]
test_batch (Avg. Loss 1.634, Accuracy 50.9): 100%|██████████| 38/38 [00:03<00:00, 11.79it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 22
ACT I. SCENE III:
I have he hath and a that come the soul to the to be the proffens a company the wo
--- EPOCH 23/50 ---
train_batch (Avg. Loss 1.646, Accuracy 52.3): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.634, Accuracy 50.9): 100%|██████████| 38/38 [00:03<00:00, 11.86it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 23
ACT I.
SCENE II:
Would he the too be this for the tongue to the reason the part we are a winds the f
--- EPOCH 24/50 ---
train_batch (Avg. Loss 1.642, Accuracy 52.4): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.631, Accuracy 51.1): 100%|██████████| 38/38 [00:03<00:00, 11.72it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 24
ACT I. SCENE II:
The hath the arms with the hand on the and deserve and our with the arm and the lov
--- EPOCH 25/50 ---
train_batch (Avg. Loss 1.640, Accuracy 52.5): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.629, Accuracy 51.0): 100%|██████████| 38/38 [00:03<00:00, 11.15it/s]
ACT I. SCENE II:
Gear the tongue the death the mother in the king the son the gates a for the mate, 
--- EPOCH 26/50 ---
train_batch (Avg. Loss 1.637, Accuracy 52.5): 100%|██████████| 348/348 [01:26<00:00,  3.89it/s]
test_batch (Avg. Loss 1.628, Accuracy 51.0): 100%|██████████| 38/38 [00:03<00:00, 11.77it/s]
ACT I. SCENE II:
Happy the king.

QUEEN ELIZABETH:
And the starn and the foul the of the ware the co
--- EPOCH 27/50 ---
train_batch (Avg. Loss 1.636, Accuracy 52.6): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.626, Accuracy 51.2): 100%|██████████| 38/38 [00:03<00:00, 11.86it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 27
ACT I.
SCENE II:
Why, and reason against the seas we have his great to confess hath with not the ton
--- EPOCH 28/50 ---
train_batch (Avg. Loss 1.634, Accuracy 52.6): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.624, Accuracy 51.2): 100%|██████████| 38/38 [00:03<00:00, 11.86it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 28
ACT I.
Scere is my lord:
He out the more in our soul the son a tack, and the comm not be a lords, of
--- EPOCH 29/50 ---
train_batch (Avg. Loss 1.632, Accuracy 52.7): 100%|██████████| 348/348 [01:26<00:00,  4.02it/s]
test_batch (Avg. Loss 1.623, Accuracy 51.2): 100%|██████████| 38/38 [00:03<00:00, 11.84it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 29
ACT I. SCENE II:
Then when he will so more.

KING RICHARD II:
I do the son and stand with thee of wh
--- EPOCH 30/50 ---
train_batch (Avg. Loss 1.630, Accuracy 52.7): 100%|██████████| 348/348 [01:26<00:00,  4.02it/s]
test_batch (Avg. Loss 1.620, Accuracy 51.2): 100%|██████████| 38/38 [00:03<00:00, 11.78it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 30
ACT I. SCENE I:
What he will not be do thy hand in the answer's rest thou the reason the flowers to 
--- EPOCH 31/50 ---
train_batch (Avg. Loss 1.628, Accuracy 52.8): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.621, Accuracy 51.3): 100%|██████████| 38/38 [00:03<00:00, 11.87it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 31
ACT I. SCENE II:
Can you good lord, and all the king the wish the stretements the head the confurert
--- EPOCH 32/50 ---
train_batch (Avg. Loss 1.627, Accuracy 52.8): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.618, Accuracy 51.3): 100%|██████████| 38/38 [00:03<00:00, 11.81it/s]
ACT I. SCENE II:
He will dear of the sall to speak the best this earth a soul a prove the our best t
--- EPOCH 33/50 ---
train_batch (Avg. Loss 1.626, Accuracy 52.8): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.617, Accuracy 51.3): 100%|██████████| 38/38 [00:03<00:00, 11.68it/s]
ACT I.
Shall he thou have the prove the son of the son so make your bark his own last a man then be 
--- EPOCH 34/50 ---
train_batch (Avg. Loss 1.624, Accuracy 52.9): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.614, Accuracy 51.4): 100%|██████████| 38/38 [00:03<00:00, 11.14it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 34
ACT I. SCENE II:
And look in him so more for the soul more in the man some to the brather to the joy
--- EPOCH 35/50 ---
train_batch (Avg. Loss 1.623, Accuracy 52.9): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.614, Accuracy 51.3): 100%|██████████| 38/38 [00:03<00:00, 11.76it/s]
ACT I. SCENE I:
What the brother that wear the fool the soul he not the death the father, the propar
--- EPOCH 36/50 ---
train_batch (Avg. Loss 1.621, Accuracy 53.0): 100%|██████████| 348/348 [01:26<00:00,  4.00it/s]
test_batch (Avg. Loss 1.613, Accuracy 51.4): 100%|██████████| 38/38 [00:03<00:00, 11.80it/s]
ACT I.
Shall hear of his former the high and thee my soul when what my consent me, the thing speak h
--- EPOCH 37/50 ---
train_batch (Avg. Loss 1.620, Accuracy 53.0): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.611, Accuracy 51.5): 100%|██████████| 38/38 [00:03<00:00, 11.74it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 37
ACT I. SCENE I:
My lord, when thou should not the word of for the worst so sleep the protection shal
--- EPOCH 38/50 ---
train_batch (Avg. Loss 1.618, Accuracy 53.0): 100%|██████████| 348/348 [01:26<00:00,  3.99it/s]
test_batch (Avg. Loss 1.608, Accuracy 51.5): 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
ACT I. SCENE II:
And we shall not me this more and the praise the part to the soul and a world in th
--- EPOCH 39/50 ---
train_batch (Avg. Loss 1.617, Accuracy 53.1): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.606, Accuracy 51.6): 100%|██████████| 38/38 [00:03<00:00, 11.83it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 39
ACT I. SCENE I:
The voices of our garments stopp and not with the war to the father that is the coun
--- EPOCH 40/50 ---
train_batch (Avg. Loss 1.616, Accuracy 53.1): 100%|██████████| 348/348 [01:26<00:00,  4.01it/s]
test_batch (Avg. Loss 1.607, Accuracy 51.6): 100%|██████████| 38/38 [00:03<00:00, 11.87it/s]
ACT I.
She will not the sight wither not the state the fair death such so happy to the forth well by
--- EPOCH 41/50 ---
train_batch (Avg. Loss 1.615, Accuracy 53.1): 100%|██████████| 348/348 [01:26<00:00,  4.02it/s]
test_batch (Avg. Loss 1.607, Accuracy 51.6): 100%|██████████| 38/38 [00:03<00:00, 11.79it/s]
ACT I. SCENE I:
I have is in the house be the word, if thou art the at a comes be the companion of w
--- EPOCH 42/50 ---
train_batch (Avg. Loss 1.614, Accuracy 53.1): 100%|██████████| 348/348 [01:21<00:00,  4.45it/s]
test_batch (Avg. Loss 1.606, Accuracy 51.7): 100%|██████████| 38/38 [00:02<00:00, 12.93it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 42
ACT I.
She the company to see the death the country, the friends with the shall not come me so fair 
--- EPOCH 43/50 ---
train_batch (Avg. Loss 1.614, Accuracy 53.1): 100%|██████████| 348/348 [01:18<00:00,  4.46it/s]
test_batch (Avg. Loss 1.605, Accuracy 51.7): 100%|██████████| 38/38 [00:02<00:00, 12.96it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 43
ACT I. QUEEN ELIZABETH:
Why is the service to in the man in the live the father, for in the way for 
--- EPOCH 44/50 ---
train_batch (Avg. Loss 1.612, Accuracy 53.2): 100%|██████████| 348/348 [01:18<00:00,  4.46it/s]
test_batch (Avg. Loss 1.602, Accuracy 51.7): 100%|██████████| 38/38 [00:02<00:00, 12.84it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 44
ACT I.
Stenrer the true at the with the san the sense the gands to the father of the courtesy the me
--- EPOCH 45/50 ---
train_batch (Avg. Loss 1.611, Accuracy 53.2): 100%|██████████| 348/348 [01:18<00:00,  4.39it/s]
test_batch (Avg. Loss 1.602, Accuracy 51.8): 100%|██████████| 38/38 [00:02<00:00, 13.08it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 45
ACT I. SCENE I:
And have when the service
The such a command'd the all the heart her thoughts flatte
--- EPOCH 46/50 ---
train_batch (Avg. Loss 1.610, Accuracy 53.2): 100%|██████████| 348/348 [01:18<00:00,  4.44it/s]
test_batch (Avg. Loss 1.601, Accuracy 51.8): 100%|██████████| 38/38 [00:02<00:00, 13.05it/s]
ACT I.
SCENE II:
The content the seal all the love and the company great serve money to the fortune,
--- EPOCH 47/50 ---
train_batch (Avg. Loss 1.609, Accuracy 53.3): 100%|██████████| 348/348 [01:18<00:00,  4.43it/s]
test_batch (Avg. Loss 1.601, Accuracy 51.8): 100%|██████████| 38/38 [00:02<00:00, 12.97it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 47
ACT I. SCENE I:
And and bear so he will see the war with his lord, and so the arms, bear the wind wa
--- EPOCH 48/50 ---
train_batch (Avg. Loss 1.608, Accuracy 53.3): 100%|██████████| 348/348 [01:18<00:00,  4.46it/s]
test_batch (Avg. Loss 1.601, Accuracy 51.9): 100%|██████████| 38/38 [00:02<00:00, 13.01it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 48
ACT I.
Scene II:
What is the matter and the father,
And my lord, we may not such a that thou our han
--- EPOCH 49/50 ---
train_batch (Avg. Loss 1.607, Accuracy 53.3): 100%|██████████| 348/348 [01:18<00:00,  4.44it/s]
test_batch (Avg. Loss 1.599, Accuracy 51.8): 100%|██████████| 38/38 [00:02<00:00, 13.02it/s]
ACT I. SCENE I:
But the peace,
The wind of God for her mind of the devil to from the sacise the son 
--- EPOCH 50/50 ---
train_batch (Avg. Loss 1.607, Accuracy 53.3): 100%|██████████| 348/348 [01:23<00:00,  3.92it/s]
test_batch (Avg. Loss 1.600, Accuracy 51.9): 100%|██████████| 38/38 [00:03<00:00, 11.65it/s]
ACT I. SCENE I:
Suffer the grace, and come the conceit with our being componting and the thing that 

Generating a work of art

Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.

TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.

In [18]:
import hw3.answers

start_seq, temperature = hw3.answers.part1_generation_params()

generated_sequence = charnn.generate_from_model(
    model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)

print(generated_sequence)
Once upon a time, the soul the soul the soul of the soul the soul the soul the soul of the soul the soul the soul the soul the soul of the soul of the soul of the soul of the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul of the soul of the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul the soul of the soul and the soul and the soul the soul the soul the soul the prove the soul of the soul the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul of the soul of the soul the soul the soul of the soul the soul and the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul of the soul of the state the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul the soul of the soul of the soul and the soul of the soul the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul of the soul the soul of the soul the soul the soul the soul the soul of the soul the prove the soul the soul the soul of the soul the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul of the soul of the soul of the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul the soul of the soul the soul of the soul the soul of the soul the soul the soul of the soul of the soul of the soul the soul the soul the soul of the soul of the soul of the soul of the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul of the soul the soul of the soul of the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul of the soul the soul of the soul of the soul of the soul the soul of the soul the soul the soul of the soul of the soul of the soul of the soul the soul the soul the soul the soul of the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul the soul the soul of the soul of the soul the soul the soul of the soul the soul of the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul the soul of the soul of the soul of the soul of the soul of the soul the soul of the soul of the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul of the soul of the soul of the soul and the soul the soul of the soul the soul the soul the soul the soul the soul of the state the soul of the soul the soul the soul the soul the soul the soul the soul the 
soul the soul of the soul the soul the soul of the soul of the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul of the soul the soul the soul of the soul of the soul the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul of the soul the soul of the soul the soul of the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul and the soul the soul the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul of the soul the soul the soul the soul the soul of the soul of the soul of the soul of the soul the soul the soul the soul the soul the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul the soul of the soul of the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul the soul of the soul the soul of the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul of the soul the soul of the world the soul the soul the soul the soul the soul of the soul of the soul the soul of the soul the soul of the soul of the soul of the soul of the soul the soul of the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul of the soul the soul the soul of the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul of the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul the soul of the soul of the soul of the soul and the soul of the soul of the soul the soul of the soul the soul the soul the soul the soul the soul the soul of the soul and the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul of the soul of the soul of the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul the soul the soul the soul of the soul the soul of the soul of the soul the soul the soul of the soul the soul of the soul of the soul the soul of the soul the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul of the soul of the soul the soul of the soul the soul of the soul of the soul of the soul of the soul the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul the soul of the soul the soul of the soul of the soul the soul the soul the soul the soul and the soul the soul of the soul the soul and the soul 
the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul the soul the soul to the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul of the soul the soul of the soul the soul the soul the soul of the soul the soul of the soul of the soul the soul the soul of the soul the soul the soul the country the soul of the soul the soul and the soul the soul the soul the soul of the soul of the world the soul the soul of the soul the soul the soul the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul the soul of the soul of the soul the soul the soul of the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul of the soul the soul of the soul the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul of the soul the soul of the soul of the soul the soul the soul of the soul the soul the soul the soul the soul of the soul the soul of the soul of the soul the soul of the soul of the soul the soul the soul of the soul the soul the soul of the soul the soul the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul the soul the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul the soul of the soul of the soul of the soul the soul the soul the soul of the soul of the soul of the soul the soul the soul the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul the soul the soul of the soul the soul the soul the soul of the soul the soul the soul the soul the soul the soul the soul the soul of the soul of the soul of the soul of the soul the soul of the soul of the soul the soul of the soul the soul the soul of the soul of the country the soul the soul the soul of the soul of the soul of the soul the soul the state the soul the soul of the soul the soul of the country the soul the soul the soul of the soul the soul the soul of the soul the soul of the soul of the soul of the soul the soul the soul of the soul of the soul of the soul of the soul of the soul the soul the soul the soul the soul the soul the soul of the soul to the soul the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul of the soul the soul of the soul of the soul of the soul of the soul the soul of the soul the soul the soul the soul of the soul of the soul the soul the soul of the soul of the soul the soul the soul the soul the soul of the soul the soul of the soul of the soul of the soul of the soul of the soul the soul the soul of the soul o

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [19]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Why do we split the corpus into sequences instead of training on the whole text?

In [20]:
display_answer(hw3.answers.part1_q1)

There is no reason to learn from a very long sequence of characters: for example, the 1000th character in the corpus is essentially unrelated to the 10th. Characters depend mostly on the sentence they belong to, so only the last few characters are relevant to the prediction. In addition, backpropagating through the entire text at once would require holding all intermediate activations in memory, which is infeasible; splitting into shorter sequences also lets us train on batches in parallel.

Question 2

How is it possible that the generated text clearly shows memory longer than the sequence length?

In [21]:
display_answer(hw3.answers.part1_q2)

The generated text seems to have memory longer than the sequence length because of the hidden state that is passed between the samples in a batch and between batches. The hidden state carries information about previously seen characters, so the model's effective memory can be longer than a single sequence.

Question 3

Why are we not shuffling the order of batches when training?

In [22]:
display_answer(hw3.answers.part1_q3)

The order of the batches reflects the structure of the text. Because the hidden state is carried across batches (as discussed in Q2), we must keep the batches in the same order as they appear in the text; otherwise the hidden state would be passed between unrelated parts of the corpus.

Question 4

  1. Why do we lower the temperature for sampling (compared to the default of $1.0$ when training)?
  2. What happens when the temperature is very high and why?
  3. What happens when the temperature is very low and why?
In [23]:
display_answer(hw3.answers.part1_q4)

A high temperature results in a distribution that is closer to uniform over the vocabulary, while a low temperature concentrates almost all the probability mass on the highest-scoring character. During training we want to consider all possible options and train accordingly, so we use the higher (default) temperature; when sampling we want the output to represent the learned distribution as closely as possible, so we use a lower temperature.
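To make the effect concrete, here is a minimal sketch of temperature-scaled softmax sampling (the function name sample_char and the stand-in logits are illustrative, not part of the assignment code):

import torch

def sample_char(logits, temperature=0.5):
    # Dividing the scores by T < 1 sharpens the softmax distribution,
    # while T > 1 flattens it towards uniform.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.randn(26)  # stand-in scores over a 26-character vocabulary
print([sample_char(logits, t) for t in (0.1, 1.0, 10.0)])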

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 2: Variational Autoencoder

In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset which contains many labeled faces of famous individuals.

We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)

However, if you feel adventurous and/or prefer to generate something else, feel free to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/nevoagmon/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/nevoagmon/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/nevoagmon/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

The Variational Autoencoder

An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e. without any labels). Recall its general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a neural net with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a neural net with parameters $\bb{\beta}$).

While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.

We define, in Bayesian terminology,

  • The prior distribution $p(\bb{Z})$ on points in the latent space.
  • The likelihood distribution of a sample $\bb{X}$ given a latent-space representation: $p(\bb{X}|\bb{Z})$.
  • The posterior distribution of points in the latent spaces given a specific instance: $p(\bb{Z}|\bb{X})$.
  • The evidence distribution $p(\bb{X})$ which is the distribution of the instance space due to the generative process.

To create our variational decoder we'll further specify:

  • A parametric likelihood distribution, $p _{\bb{\beta}}(\bb{X} | \bb{z}) = \mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$. The interpretation is that given a latent $\bb{z}$, we map it to a point normally distributed around the point calculated by our decoder neural network. Note that here $\sigma^2$ is a hyperparameter while $\vec{\beta}$ represents the network parameters.
  • A fixed latent-space prior distribution of $p(\bb{Z}) = \mathcal{N}(\bb{0},\bb{I})$.

This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.

Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) = \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder neural network, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ directly from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is to use what's known as the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma} _{\bb{\alpha}}(\bb{x})$, where $\bb{\sigma} _{\bb{\alpha}}(\bb{x})$ is the element-wise square root of the variance.
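Here is a minimal sketch of the reparametrization trick in isolation (the tensor shapes are illustrative; in the model, mu and log_sigma2 come from the encoder):

import torch

mu = torch.randn(1, 2, requires_grad=True)          # encoder's posterior mean
log_sigma2 = torch.randn(1, 2, requires_grad=True)  # encoder's posterior log-variance

u = torch.randn_like(mu)             # u ~ N(0, I): no trainable parameters involved
sigma = torch.exp(0.5 * log_sigma2)  # standard deviation from the log-variance
z = mu + u * sigma                   # z is differentiable w.r.t. mu and log_sigma2

z.sum().backward()                   # gradients flow back to the encoder outputs
print(mu.grad, log_sigma2.grad)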

To train a VAE model, we would like to maximize the evidence, $p(\bb{X})$; since $ p(\bb{X}) = \int p(\bb{X}|{\bb{z}})p(\bb{z})d\bb{z} $, maximizing it maximizes the likelihood of generated instances over the entire latent space.

The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{X})$. As we saw in the lecture, this quantity is intractable to compute directly, but we can obtain a lower bound for $\log p(\bb{X})$ (the evidence lower bound, "ELBO"):

$$ \log p(\bb{X}) \ge \mathbb{E}_{\bb{z} \sim q_{\bb{\alpha}}} \left[ \log p_{\bb{\beta}}(\bb{X} | \bb{z}) \right] - \mathcal{D}_{\mathrm{KL}}\left( q_{\bb{\alpha}}(\bb{Z} | \bb{X}) \,\middle\|\, p(\bb{Z}) \right) $$

where $ \mathcal{D}_{\mathrm{KL}}\left(q \,\middle\|\, p\right) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{Z})}{p(\bb{Z})} \right] $ is the Kullback–Leibler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z}|\bb{X})$ instead of the prior distribution $p(\bb{Z})$.

Using the ELBO, the VAE loss becomes $$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E}_{\bb{x}} \left[ \mathbb{E}_{\bb{z} \sim q_{\bb{\alpha}}}\left[ -\log p_{\bb{\beta}}(\bb{x} | \bb{z}) \right] + \mathcal{D}_{\mathrm{KL}}\left( q_{\bb{\alpha}}(\bb{Z} | \bb{x}) \,\middle\|\, p(\bb{Z}) \right) \right]. $$

By remembering that the likelihood is a Gaussian distribution with a diagonal covariance and by applying the reparametrization trick, we can write the above as

$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E}_{\bb{x}} \left[ \mathbb{E}_{\bb{u} \sim \mathcal{N}(\bb{0},\bb{I})} \left[ \frac{1}{2\sigma^2}\left\| \bb{x} - \Psi_{\bb{\beta}}\left( \bb{\mu}_{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}}_{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\|_2^2 \right] + \mathcal{D}_{\mathrm{KL}}\left( q_{\bb{\alpha}}(\bb{Z} | \bb{x}) \,\middle\|\, p(\bb{Z}) \right) \right]. $$

Model Implementation

Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).

First let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map it to an activation volume of a specified depth. We'll consider this volume as the features we extract from the input image; later we'll use these features to create the latent-space representation of the input.

TODO: Implement the EncoderCNN class in the hw3/autoencoder.py module. Implement any CNN architecture you like. If you need "architecture inspiration" you can see e.g. this or this paper.

In [6]:
import hw3.autoencoder as autoencoder

in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)

h = encoder_cnn(x0)
print(h.shape)

test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
  (cnn): Sequential(
    (0): Conv2d(3, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(1000, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
  )
)
torch.Size([1, 1024, 2, 2])

Now let's implement the CNN part of the Decoder. Again, this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced by your EncoderCNN and output an image with the same dimensions as the Encoder's input. This should be a CNN which is a "mirror image" of the Encoder: for example, replace convolutions with transposed convolutions, downsampling with up-sampling, etc. Consult the documentation of ConvTranspose2D to figure out how to reverse your convolutional layers in terms of input and output dimensions.

TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.

In [7]:
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)

test.assertEqual(x0.shape, x0r.shape)

# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
  (cnn): Sequential(
    (0): ConvTranspose2d(1024, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ConvTranspose2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ConvTranspose2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ConvTranspose2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): ReLU()
    (11): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ConvTranspose2d(1000, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
    (14): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
torch.Size([1, 3, 64, 64])
Out[7]:

Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:

  1. Produce a feature vector $\vec{h}$ from the input image $\vec{x}$.
  2. Use two affine transforms to convert the features into the mean and log-variance of the posterior, i.e. $$ \begin{align} \bb{\mu}_{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\ \log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}} \end{align} $$
  3. Use the reparametrization trick to create the latent representation $\vec{z}$.

Note that we model the log of the variance, not the actual variance. The reason is that the log is easier to optimize, since (a) it doesn't have to be positive, and (b) it has a much larger dynamic range. The above formulation is proposed in appendix C of the VAE paper.
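As a rough sketch of steps 1–3 above (features_cnn is an EncoderCNN-like module; w_mu and w_logvar stand for the two nn.Linear heads, and the flattening assumption is ours):

import torch

def encode_sketch(features_cnn, w_mu, w_logvar, x):
    h = features_cnn(x).flatten(start_dim=1)   # (N, C, H, W) -> (N, C*H*W)
    mu = w_mu(h)                               # posterior mean
    log_sigma2 = w_logvar(h)                   # posterior log-variance
    u = torch.randn_like(mu)                   # reparametrization trick
    z = mu + u * torch.exp(0.5 * log_sigma2)
    return z, mu, log_sigma2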

TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__().

In [8]:
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)

z, mu, log_sigma2 = vae.encode(x0)

test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)

print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')

# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
    for i in range(500):
        Z[i], _, _ = vae.encode(x0)
        ax.scatter(*Z[i].cpu().numpy())

# Should be close to the above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Conv2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Conv2d(1000, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (13): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ConvTranspose2d(1024, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ConvTranspose2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): ReLU()
      (5): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ConvTranspose2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ReLU()
      (8): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ConvTranspose2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): ReLU()
      (11): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ConvTranspose2d(1000, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (13): ReLU()
      (14): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (mu): Linear(in_features=4096, out_features=2, bias=True)
  (logvar): Linear(in_features=4096, out_features=2, bias=True)
  (rec): Linear(in_features=2, out_features=4096, bias=True)
)
mu(x0)=[0.062451795, -0.006458247], sigma2(x0)=[0.7509018, 1.2003356]
sampled mu tensor([ 0.0605, -0.0016])
sampled sigma2 tensor([0.1343, 0.3645])

Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It will work as follows:

  1. Produce a feature vector $\tilde{\vec{h}}$ from the latent vector $\vec{z}$ using an affine transform.
  2. Reconstruct an image $\tilde{\vec{x}}$ from $\tilde{\vec{h}}$.

TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.

In [9]:
x0r = vae.decode(z)

test.assertSequenceEqual(x0r.shape, x0.shape)

Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.

In [10]:
x0r, mu, log_sigma2 = vae(x0)

test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
Out[10]:

Loss Implementation

In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:

$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}) $$

where $d_z$ is the dimension of the latent space. This pointwise loss is the quantity that we'll compute and minimize with gradient descent.
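For reference, the closed-form KL divergence between the diagonal-Gaussian posterior and the standard-Normal prior is $$ \mathcal{D}_{\mathrm{KL}}\left( \mathcal{N}(\bb{\mu}, \bb{\Sigma}) \,\middle\|\, \mathcal{N}(\bb{0}, \bb{I}) \right) = \frac{1}{2}\left( \mathrm{tr}\,\bb{\Sigma} + \norm{\bb{\mu}}_2^2 - d_z - \log\det\bb{\Sigma} \right), $$ which is where the trace, norm, $d_z$ and log-det terms above come from (the pointwise loss drops the overall factor of $\frac{1}{2}$). A minimal sketch of computing this loss from the encoder outputs follows; note that normalization of the reconstruction term is a design choice, so this sketch will not necessarily reproduce the exact test value below:

import torch

def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    # Reconstruction term: (1/sigma^2) * ||x - xr||^2 per instance.
    data_loss = (x - xr).pow(2).flatten(start_dim=1).sum(dim=1) / x_sigma2
    # Closed-form KL term: tr(Sigma) + ||mu||^2 - d_z - log det(Sigma);
    # Sigma is diagonal, so tr(Sigma) = sum(exp(log_sigma2)) and
    # log det(Sigma) = sum(log_sigma2).
    kldiv_loss = (z_log_sigma2.exp().sum(dim=1) + z_mu.pow(2).sum(dim=1)
                  - z_mu.shape[1] - z_log_sigma2.sum(dim=1))
    loss = (data_loss + kldiv_loss).mean()
    return loss, data_loss.mean(), kldiv_loss.mean()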

TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.

In [11]:
from hw3.autoencoder import vae_loss
torch.manual_seed(42)

def test_vae_loss():
    # Test data
    N, C, H, W = 10, 3, 64, 64 
    z_dim = 32
    x  = torch.randn(N, C, H, W)*2 - 1
    xr = torch.randn(N, C, H, W)*2 - 1
    z_mu = torch.randn(N, z_dim)
    z_log_sigma2 = torch.randn(N, z_dim)
    x_sigma2 = 0.9
    
    loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
    
    test.assertAlmostEqual(loss.item(), 10.5053434, delta=1e-5)
    return loss

test_vae_loss()
Out[11]:
tensor(10.5053)

Sampling

The main advantage of a VAE is that it can be used as a generative model by sampling from the latent space, since the loss function pushes the posterior towards the Normal prior $p(\bb{Z})$. Let's now implement this so that we can visualize how our model is doing as we train.
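In code, sampling amounts to drawing from the prior and decoding; a minimal sketch, assuming the VAE exposes a z_dim attribute and the decode() method implemented above:

import torch

def vae_sample_sketch(vae, n, device):
    # No gradients are needed for sampling, so wrap it in no_grad().
    with torch.no_grad():
        z = torch.randn(n, vae.z_dim, device=device)  # z ~ N(0, I)
        samples = vae.decode(z)
    return samples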

TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.

In [12]:
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)

Training

Time to train!

TODO:

  1. Implement the VAETrainer class in the hw3/training.py module.
  2. Tweak the hyperparameters in the part2_vae_hyperparam() function within the hw3/answers.py module.
In [13]:
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']

# Data
split_lengths = [int(len(ds_gwb)*0.9), int(len(ds_gwb)*0.1)]
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test  = DataLoader(ds_test,  batch_size, shuffle=True)
im_size = ds_train[0][0].shape

# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)

# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)

# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
    return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)

# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show model and hypers
print(vae)
print(hp)
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Conv2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Conv2d(1000, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (13): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ConvTranspose2d(1024, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ConvTranspose2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): ReLU()
      (5): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ConvTranspose2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ReLU()
      (8): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ConvTranspose2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): ReLU()
      (11): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ConvTranspose2d(1000, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (13): ReLU()
      (14): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (mu): Linear(in_features=4096, out_features=2, bias=True)
  (logvar): Linear(in_features=4096, out_features=2, bias=True)
  (rec): Linear(in_features=2, out_features=4096, bias=True)
)
{'batch_size': 32, 'h_dim': 1024, 'z_dim': 2, 'x_sigma2': 5, 'learn_rate': 0.0005, 'betas': (0.9, 0.999)}
In [14]:
import IPython.display

def post_epoch_fn(epoch, train_result, test_result, verbose):
    # Plot some samples if this is a verbose epoch
    if verbose:
        samples = vae.sample(n=5)
        fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
        IPython.display.display(fig)
        plt.close(fig)

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    checkpoint_file = checkpoint_file_final
else:
    res = trainer.fit(dl_train, dl_test,
                      num_epochs=200, early_stopping=20, print_every=10,
                      checkpoints=checkpoint_file,
                      post_epoch_fn=post_epoch_fn)
    
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
--- EPOCH 1/200 ---
train_batch (Avg. Loss 2.681, Accuracy 48.2): 100%|██████████| 15/15 [00:05<00:00,  3.01it/s]
test_batch (Avg. Loss 0.305, Accuracy 63.3): 100%|██████████| 2/2 [00:00<00:00,  4.95it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 1
*** Saved checkpoint checkpoints/vae.pt at epoch 3
*** Saved checkpoint checkpoints/vae.pt at epoch 6
*** Saved checkpoint checkpoints/vae.pt at epoch 9
--- EPOCH 11/200 ---
train_batch (Avg. Loss 0.057, Accuracy 55.7): 100%|██████████| 15/15 [00:05<00:00,  2.75it/s]
test_batch (Avg. Loss 0.055, Accuracy 69.0): 100%|██████████| 2/2 [00:00<00:00,  4.26it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 11
*** Saved checkpoint checkpoints/vae.pt at epoch 12
*** Saved checkpoint checkpoints/vae.pt at epoch 13
*** Saved checkpoint checkpoints/vae.pt at epoch 14
*** Saved checkpoint checkpoints/vae.pt at epoch 18
--- EPOCH 21/200 ---
train_batch (Avg. Loss 0.054, Accuracy 58.9): 100%|██████████| 15/15 [00:05<00:00,  3.03it/s]
test_batch (Avg. Loss 0.053, Accuracy 71.8): 100%|██████████| 2/2 [00:00<00:00,  4.89it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 21
*** Saved checkpoint checkpoints/vae.pt at epoch 22
*** Saved checkpoint checkpoints/vae.pt at epoch 24
--- EPOCH 31/200 ---
train_batch (Avg. Loss 0.052, Accuracy 60.4): 100%|██████████| 15/15 [00:04<00:00,  3.08it/s]
test_batch (Avg. Loss 0.051, Accuracy 74.8): 100%|██████████| 2/2 [00:00<00:00,  5.21it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 31
--- EPOCH 41/200 ---
train_batch (Avg. Loss 0.052, Accuracy 61.1): 100%|██████████| 15/15 [00:04<00:00,  2.89it/s]
test_batch (Avg. Loss 0.053, Accuracy 71.8): 100%|██████████| 2/2 [00:00<00:00,  4.10it/s]
--- EPOCH 51/200 ---
train_batch (Avg. Loss 0.052, Accuracy 61.4): 100%|██████████| 15/15 [00:04<00:00,  3.15it/s]
test_batch (Avg. Loss 0.052, Accuracy 72.2): 100%|██████████| 2/2 [00:00<00:00,  5.08it/s]
--- EPOCH 61/200 ---
train_batch (Avg. Loss 0.051, Accuracy 61.8): 100%|██████████| 15/15 [00:04<00:00,  3.12it/s]
test_batch (Avg. Loss 0.052, Accuracy 73.1): 100%|██████████| 2/2 [00:00<00:00,  4.96it/s]
--- EPOCH 71/200 ---
train_batch (Avg. Loss 0.051, Accuracy 61.7): 100%|██████████| 15/15 [00:04<00:00,  3.26it/s]
test_batch (Avg. Loss 0.052, Accuracy 72.3): 100%|██████████| 2/2 [00:00<00:00,  5.15it/s]
--- EPOCH 81/200 ---
train_batch (Avg. Loss 0.051, Accuracy 61.7): 100%|██████████| 15/15 [00:04<00:00,  3.16it/s]
test_batch (Avg. Loss 0.052, Accuracy 72.6): 100%|██████████| 2/2 [00:00<00:00,  4.96it/s]
--- EPOCH 91/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.3): 100%|██████████| 15/15 [00:05<00:00,  2.75it/s]
test_batch (Avg. Loss 0.051, Accuracy 74.0): 100%|██████████| 2/2 [00:00<00:00,  4.22it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 94
--- EPOCH 101/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.3): 100%|██████████| 15/15 [00:04<00:00,  3.09it/s]
test_batch (Avg. Loss 0.052, Accuracy 72.1): 100%|██████████| 2/2 [00:00<00:00,  4.92it/s]
--- EPOCH 111/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.4): 100%|██████████| 15/15 [00:04<00:00,  3.11it/s]
test_batch (Avg. Loss 0.051, Accuracy 73.7): 100%|██████████| 2/2 [00:00<00:00,  4.74it/s]
--- EPOCH 121/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.5): 100%|██████████| 15/15 [00:04<00:00,  3.12it/s]
test_batch (Avg. Loss 0.053, Accuracy 72.3): 100%|██████████| 2/2 [00:00<00:00,  4.97it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 122
--- EPOCH 131/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.3): 100%|██████████| 15/15 [00:05<00:00,  2.80it/s]
test_batch (Avg. Loss 0.052, Accuracy 73.4): 100%|██████████| 2/2 [00:00<00:00,  4.11it/s]
--- EPOCH 141/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.3): 100%|██████████| 15/15 [00:04<00:00,  3.17it/s]
test_batch (Avg. Loss 0.053, Accuracy 72.1): 100%|██████████| 2/2 [00:00<00:00,  5.35it/s]
--- EPOCH 151/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.3): 100%|██████████| 15/15 [00:05<00:00,  2.86it/s]
test_batch (Avg. Loss 0.053, Accuracy 72.2): 100%|██████████| 2/2 [00:00<00:00,  3.84it/s]
--- EPOCH 161/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.6): 100%|██████████| 15/15 [00:04<00:00,  3.09it/s]
test_batch (Avg. Loss 0.052, Accuracy 72.5): 100%|██████████| 2/2 [00:00<00:00,  4.57it/s]
--- EPOCH 171/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.5): 100%|██████████| 15/15 [00:04<00:00,  3.17it/s]
test_batch (Avg. Loss 0.053, Accuracy 72.1): 100%|██████████| 2/2 [00:00<00:00,  5.29it/s]
--- EPOCH 181/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.4): 100%|██████████| 15/15 [00:04<00:00,  3.01it/s]
test_batch (Avg. Loss 0.051, Accuracy 75.1): 100%|██████████| 2/2 [00:00<00:00,  5.04it/s]
--- EPOCH 191/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.6): 100%|██████████| 15/15 [00:04<00:00,  3.08it/s]
test_batch (Avg. Loss 0.052, Accuracy 74.0): 100%|██████████| 2/2 [00:00<00:00,  5.09it/s]
--- EPOCH 200/200 ---
train_batch (Avg. Loss 0.051, Accuracy 62.6): 100%|██████████| 15/15 [00:05<00:00,  2.88it/s]
test_batch (Avg. Loss 0.052, Accuracy 73.3): 100%|██████████| 2/2 [00:00<00:00,  5.40it/s]
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [15]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.

In [16]:
display_answer(hw3.answers.part2_q1)

The x_sigma2 hyperparameter controls the trade-off between the reconstruction (data) loss and the KL-divergence loss in the loss function: the reconstruction term is divided by $\sigma^2$, so for high values the weight of the data loss is small and the KL term dominates, and vice versa.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 3: Generative Adversarial Networks

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [2]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cuda

Obtaining the dataset

We'll use the same data as in Part 2.

But again, to use a custom dataset, edit the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.

In [3]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/nevoagmon/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/nevoagmon/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/nevoagmon/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [4]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [5]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [6]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

Generative Adversarial Nets (GANs)

GANs, first proposed in a paper by Ian Goodfellow in 2014, are today arguably the most popular type of generative model, currently producing state-of-the-art results on generative tasks across many different domains.

In a GAN model, two different neural networks compete against each other: A generator and a discriminator.

  • The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{U} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). Thus a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$ is generated, which we typically would like to be as close as possible to the real evidence distribution, $p(\bb{X})$.

  • The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

Training GANs

The generator is trained to generate "fake" instances which will maximally fool the discriminator into returning that they're real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:

$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

This means that the generator's loss function trains together with the generator itself in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as a data loss term.

Model Implementation

We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.

TODO: Implement the Discriminator class in the hw3/gan.py module. If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.

In [7]:
import hw3.gan as gan

dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)

d0 = dsc(x0)
print(d0.shape)

test.assertSequenceEqual(d0.shape, (1,1))
Discriminator(
  (cnn): Sequential(
    (0): Conv2d(3, 250, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
  (classifier): Sequential(
    (0): Linear(in_features=16000, out_features=4, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2)
    (3): Linear(in_features=4, out_features=1, bias=True)
  )
)
torch.Size([1, 1])

TODO: Implement the Generator class in the hw3/gan.py module. If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.

In [31]:
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)

z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)

test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
  (generator): Sequential(
    (0): ConvTranspose2d(128, 250, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): BatchNorm2d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ConvTranspose2d(250, 500, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ConvTranspose2d(500, 750, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): BatchNorm2d(750, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ConvTranspose2d(750, 1000, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): ReLU()
    (11): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ConvTranspose2d(1000, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
    (14): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
torch.Size([1, 3, 64, 64])

Loss Implementation

Let's begin with the discriminator's loss function. Based on the above, we can flip the sign and say we want to update the Discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$ -\mathbb{E}_{\bb{x} \sim p(\bb{X})} \log \Delta_{\bb{\delta}}(\bb{x}) \, - \, \mathbb{E}_{\bb{z} \sim p(\bb{Z})} \log \left(1-\Delta_{\bb{\delta}}(\Psi_{\bb{\gamma}}(\bb{z}))\right). $$

We're using the Discriminator twice in this expression; once to classify data from the real data distribution and once again to classify generated data. Therefore our loss should be computed based on these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two cross-entropy losses.

GANs are notoriously difficult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization, to help prevent the discriminator from overfitting.

We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.
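A sketch of how the two cross-entropy terms with fuzzy labels could look, assuming the discriminator outputs raw scores (logits); the exact noise-width convention is a design choice left to the implementation:

import torch
import torch.nn.functional as F

def discriminator_loss_sketch(y_data, y_generated, data_label, label_noise):
    # Fuzzy targets: data_label +- noise for real scores and
    # (1 - data_label) +- noise for generated ones.
    real_targets = data_label + (torch.rand_like(y_data) - 0.5) * label_noise
    fake_targets = (1 - data_label) + (torch.rand_like(y_generated) - 0.5) * label_noise
    # Two cross-entropy terms, one per use of the discriminator.
    loss_data = F.binary_cross_entropy_with_logits(y_data, real_targets)
    loss_generated = F.binary_cross_entropy_with_logits(y_generated, fake_targets)
    return loss_data + loss_generated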

TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.

In [9]:
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)

y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10

loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)

test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)

Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$

which can also be seen as a cross-entropy term.
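A corresponding sketch for the generator side (same logits assumption as above):

import torch
import torch.nn.functional as F

def generator_loss_sketch(y_generated, data_label):
    # The generator is rewarded when the discriminator assigns
    # the "real" label to its generated samples.
    targets = torch.full_like(y_generated, float(data_label))
    return F.binary_cross_entropy_with_logits(y_generated, targets)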

TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.

In [10]:
from hw3.gan import generator_loss_fn
torch.manual_seed(42)

y_generated = torch.rand(20) * 10

loss = generator_loss_fn(y_generated, data_label=1)
print(loss)

test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-5)
tensor(0.0223)

Sampling

Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.

There is an important nuance, however: sampling is also required while training the GAN, since we generate fake images to show the discriminator. As you'll see in the next section, in some cases we'll need our samples to carry gradients.
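One possible shape for sample(), assuming the Generator stores its latent dimension as z_dim:

import torch

def gan_sample_sketch(gen, n, with_grad=False):
    device = next(gen.parameters()).device
    # Enable gradients only when the samples take part in a backward pass
    # (e.g. when training the generator through the discriminator).
    ctx = torch.enable_grad() if with_grad else torch.no_grad()
    with ctx:
        z = torch.randn(n, gen.z_dim, device=device)
        samples = gen(z)
    return samples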

TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.

In [11]:
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())

samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)

Training

Training GANs is a bit different since we need to train two models simultaneously, each with its own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.
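At a high level, a single batch update could look like the sketch below (the signature mirrors the call in the training loop; details such as re-sampling fresh fakes for the generator step are design choices, not the reference solution):

def train_batch_sketch(dsc_model, gen_model, dsc_loss_fn, gen_loss_fn,
                       dsc_optimizer, gen_optimizer, x_data):
    # 1. Discriminator step: real scores vs. scores on fakes sampled
    #    without grad, so only the discriminator's parameters update.
    dsc_optimizer.zero_grad()
    y_data = dsc_model(x_data)
    y_generated = dsc_model(gen_model.sample(x_data.shape[0], with_grad=False))
    dsc_loss = dsc_loss_fn(y_data, y_generated)
    dsc_loss.backward()
    dsc_optimizer.step()

    # 2. Generator step: fresh fakes sampled *with* grad so gradients
    #    flow from the discriminator's scores back into the generator.
    gen_optimizer.zero_grad()
    gen_loss = gen_loss_fn(dsc_model(gen_model.sample(x_data.shape[0], with_grad=True)))
    gen_loss.backward()
    gen_optimizer.step()

    return dsc_loss.item(), gen_loss.item()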

As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)

TODO:

  1. Implement the train_batch function in the hw3/gan.py module.
  2. Tweak the hyperparameters in the part3_gan_hyperparam() function within the hw3/answers.py module.
In [57]:
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']

# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape

# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)

# Optimizer
def create_optimizer(model_params, opt_params):
    opt_params = opt_params.copy()
    optimizer_type = opt_params['type']
    opt_params.pop('type')
    return optim.__dict__[optimizer_type](model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])

# Loss
def dsc_loss_fn(y_data, y_generated):
    return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])

def gen_loss_fn(y_generated):
    return gan.generator_loss_fn(y_generated, hp['data_label'])

# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show hypers
print(hp)
{'batch_size': 32, 'z_dim': 10, 'data_label': 1, 'label_noise': 0.3, 'discriminator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'weight_decay': 0.002, 'betas': (0.5, 0.999)}, 'generator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'weight_decay': 0.002, 'betas': (0.4, 0.999)}}
In [59]:
import IPython.display
import tqdm
from hw3.gan import train_batch

num_epochs = 100

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    num_epochs = 0
    gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device)
    checkpoint_file = checkpoint_file_final

for epoch_idx in range(num_epochs):
    # We'll accumulate batch losses and show an average once per epoch.
    dsc_losses = []
    gen_losses = []
    print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')
    
    with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
        for batch_idx, (x_data, _) in enumerate(dl_train):
            x_data = x_data.to(device)
            dsc_loss, gen_loss = train_batch(
                dsc, gen,
                dsc_loss_fn, gen_loss_fn,
                dsc_optimizer, gen_optimizer,
                x_data)
            dsc_losses.append(dsc_loss)
            gen_losses.append(gen_loss)
            pbar.update()

    dsc_avg_loss, gen_avg_loss = np.mean(dsc_losses), np.mean(gen_losses)
    print(f'Discriminator loss: {dsc_avg_loss}')
    print(f'Generator loss:     {gen_avg_loss}')
        
    samples = gen.sample(5, with_grad=False)
    fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
    IPython.display.display(fig)
    plt.close(fig)
--- EPOCH 1/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.99it/s]
Discriminator loss: 0.6429693652864765
Generator loss:     7.152631899889777
--- EPOCH 2/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 1.1279077284476335
Generator loss:     6.100880160051234
--- EPOCH 3/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.97it/s]
Discriminator loss: 1.149587852113387
Generator loss:     4.052209903212154
--- EPOCH 4/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.6314365495653713
Generator loss:     2.737534298616297
--- EPOCH 5/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.86it/s]
Discriminator loss: 0.44448774702408733
Generator loss:     4.112388148027308
--- EPOCH 6/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.90it/s]
Discriminator loss: 0.31292344542110667
Generator loss:     4.673804493511424
--- EPOCH 7/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.78it/s]
Discriminator loss: 1.048147408401265
Generator loss:     3.6689783615224503
--- EPOCH 8/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.81it/s]
Discriminator loss: 0.3827902551959543
Generator loss:     3.9134428360882927
--- EPOCH 9/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.78it/s]
Discriminator loss: 0.4209982060334262
Generator loss:     4.4366773717543655
--- EPOCH 10/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.6943186030668371
Generator loss:     3.741685656940236
--- EPOCH 11/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.8188287072321948
Generator loss:     3.738111783476437
--- EPOCH 12/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.5261569768190384
Generator loss:     3.8838668167591095
--- EPOCH 13/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.96it/s]
Discriminator loss: 1.0280034261591293
Generator loss:     3.271259209688972
--- EPOCH 14/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.7997743925627541
Generator loss:     3.536908374113195
--- EPOCH 15/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.8743297597941231
Generator loss:     3.229924566605512
--- EPOCH 16/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.97it/s]
Discriminator loss: 0.8986433130853316
Generator loss:     3.0789449670735527
--- EPOCH 17/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.90it/s]
Discriminator loss: 0.6708083959186778
Generator loss:     3.5620327837326946
--- EPOCH 18/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.55542272679946
Generator loss:     3.6218428822124706
--- EPOCH 19/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 1.0900565455941593
Generator loss:     3.155969623257132
--- EPOCH 20/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.89it/s]
Discriminator loss: 0.6741697805769303
Generator loss:     3.466335822554196
--- EPOCH 21/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.76it/s]
Discriminator loss: 1.042045130449183
Generator loss:     2.6953003424055435
--- EPOCH 22/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.80it/s]
Discriminator loss: 0.7716695561128504
Generator loss:     3.1776481411036324
--- EPOCH 23/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.85it/s]
Discriminator loss: 0.9586734350989846
Generator loss:     2.7567259003134335
--- EPOCH 24/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.90it/s]
Discriminator loss: 0.8427364054848167
Generator loss:     3.086686604163226
--- EPOCH 25/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.85it/s]
Discriminator loss: 0.8760731185183805
Generator loss:     3.3450410646550797
--- EPOCH 26/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.8785888938342824
Generator loss:     3.037569414166843
--- EPOCH 27/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.87it/s]
Discriminator loss: 0.9628784253316767
Generator loss:     2.451293682350832
--- EPOCH 28/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.70it/s]
Discriminator loss: 1.019276415600496
Generator loss:     2.3465735807138333
--- EPOCH 29/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.86it/s]
Discriminator loss: 0.9531389913138222
Generator loss:     2.6253413803437176
--- EPOCH 30/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.81it/s]
Discriminator loss: 1.0157475962358362
Generator loss:     2.161858139669194
--- EPOCH 31/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.79it/s]
Discriminator loss: 0.7807056220138774
Generator loss:     2.815427576794344
--- EPOCH 32/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 1.0692532658576965
Generator loss:     2.5785896830699024
--- EPOCH 33/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.87it/s]
Discriminator loss: 1.0597794161123388
Generator loss:     2.1992763999630425
--- EPOCH 34/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.89it/s]
Discriminator loss: 0.9788953381426194
Generator loss:     2.4706613000701454
--- EPOCH 35/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.86it/s]
Discriminator loss: 1.0351756144972408
Generator loss:     2.1890082937829636
--- EPOCH 36/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.84it/s]
Discriminator loss: 0.9825181400074678
Generator loss:     2.297637501183678
--- EPOCH 37/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.8581356897073633
Generator loss:     2.6314883021747364
--- EPOCH 38/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.96it/s]
Discriminator loss: 0.8965897910735187
Generator loss:     2.650883625535404
--- EPOCH 39/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 0.9855937326655668
Generator loss:     2.431456586893867
--- EPOCH 40/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.8133983506875879
Generator loss:     2.675472897641799
--- EPOCH 41/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 1.0717309187440311
Generator loss:     2.2148075699806213
--- EPOCH 42/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.89it/s]
Discriminator loss: 1.00506209099994
Generator loss:     2.382066582932192
--- EPOCH 43/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.96it/s]
Discriminator loss: 0.9061266148791594
Generator loss:     2.4974488300435684
--- EPOCH 44/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.8686574346878949
Generator loss:     2.8221396733732784
--- EPOCH 45/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.9775093092637903
Generator loss:     2.606570047490737
--- EPOCH 46/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.9571589441860423
Generator loss:     2.6859100145452164
--- EPOCH 47/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.98it/s]
Discriminator loss: 0.8417629389201894
Generator loss:     2.7056742976693546
--- EPOCH 48/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.96it/s]
Discriminator loss: 0.7238037147942711
Generator loss:     2.983083451495451
--- EPOCH 49/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.9411259538987103
Generator loss:     3.0449262015959797
--- EPOCH 50/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 0.8180542886257172
Generator loss:     3.011088441399967
--- EPOCH 51/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.86it/s]
Discriminator loss: 0.8471382751184351
Generator loss:     2.8986567048465504
--- EPOCH 52/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.96it/s]
Discriminator loss: 0.7820537651286406
Generator loss:     3.1060519569060383
--- EPOCH 53/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.97it/s]
Discriminator loss: 0.7699516766211566
Generator loss:     3.1389779132955216
--- EPOCH 54/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.8486169979852789
Generator loss:     3.137370404075174
--- EPOCH 55/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.00it/s]
Discriminator loss: 0.6681408549056334
Generator loss:     3.320000003365909
--- EPOCH 56/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.698312899645637
Generator loss:     3.3518648831283344
--- EPOCH 57/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 1.0867561101913452
Generator loss:     2.8441466093063354
--- EPOCH 58/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.02it/s]
Discriminator loss: 0.8884692121954525
Generator loss:     3.0175412682925953
--- EPOCH 59/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.97it/s]
Discriminator loss: 0.5969304091790143
Generator loss:     3.2264910024755142
--- EPOCH 60/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.85it/s]
Discriminator loss: 0.7342718664337607
Generator loss:     3.602457242853501
--- EPOCH 61/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.7789926493869108
Generator loss:     3.3017596076516544
--- EPOCH 62/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.87it/s]
Discriminator loss: 0.7295236517401302
Generator loss:     3.324652489493875
--- EPOCH 63/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.5288006417891559
Generator loss:     3.7635633384480194
--- EPOCH 64/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 0.6833070499055526
Generator loss:     3.6033018476822796
--- EPOCH 65/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 0.6058792349170236
Generator loss:     3.6700223263572243
--- EPOCH 66/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.87it/s]
Discriminator loss: 0.7250925714478773
Generator loss:     3.6562275956658756
--- EPOCH 67/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.6995201180962956
Generator loss:     3.6700789086958943
--- EPOCH 68/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.85it/s]
Discriminator loss: 0.4788511509404463
Generator loss:     4.084004261914422
--- EPOCH 69/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.84it/s]
Discriminator loss: 0.45396612146321463
Generator loss:     3.9392987840315876
--- EPOCH 70/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.7961510419845581
Generator loss:     4.045468042878544
--- EPOCH 71/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.7774021923542023
Generator loss:     3.4907390931073357
--- EPOCH 72/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.5544340321246315
Generator loss:     4.133463866570416
--- EPOCH 73/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.96it/s]
Discriminator loss: 0.7337903275209314
Generator loss:     3.6170143169515274
--- EPOCH 74/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.6153087703620687
Generator loss:     3.9992057856391456
--- EPOCH 75/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.90it/s]
Discriminator loss: 0.5721007673179402
Generator loss:     4.09680157549241
--- EPOCH 76/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.41280702983631806
Generator loss:     4.395288691801183
--- EPOCH 77/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.90it/s]
Discriminator loss: 0.4556984866366667
Generator loss:     4.6250403067644905
--- EPOCH 78/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.5250876011217341
Generator loss:     4.527703397414264
--- EPOCH 79/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.89it/s]
Discriminator loss: 0.712657663752051
Generator loss:     4.126779693014481
--- EPOCH 80/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.89it/s]
Discriminator loss: 0.5055013246396008
Generator loss:     4.214290745118085
--- EPOCH 81/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.4368143607588375
Generator loss:     4.280328890856574
--- EPOCH 82/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.3901726410669439
Generator loss:     4.489644541459925
--- EPOCH 83/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.5570309547817006
Generator loss:     4.843409419059753
--- EPOCH 84/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.87it/s]
Discriminator loss: 0.5582765971913057
Generator loss:     4.488968218074126
--- EPOCH 85/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.6052223461515763
Generator loss:     4.482523805954877
--- EPOCH 86/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 0.44817916347700004
Generator loss:     4.395666038288789
--- EPOCH 87/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.95it/s]
Discriminator loss: 0.41032591549789205
Generator loss:     4.619030026828542
--- EPOCH 88/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.94it/s]
Discriminator loss: 0.39996321762309356
Generator loss:     4.804345719954547
--- EPOCH 89/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.87it/s]
Discriminator loss: 0.603129514876534
Generator loss:     4.939912922242108
--- EPOCH 90/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.5676895380020142
Generator loss:     4.496359180001652
--- EPOCH 91/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.91it/s]
Discriminator loss: 0.3394545255338444
Generator loss:     4.61476980938631
--- EPOCH 92/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.88it/s]
Discriminator loss: 0.36849100186544304
Generator loss:     4.787350416183472
--- EPOCH 93/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.93it/s]
Discriminator loss: 0.4371252533267526
Generator loss:     5.095597183003145
--- EPOCH 94/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.84it/s]
Discriminator loss: 0.7409432530403137
Generator loss:     4.271258185891544
--- EPOCH 95/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.80it/s]
Discriminator loss: 0.4629165582797107
Generator loss:     4.572179261375876
--- EPOCH 96/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.86it/s]
Discriminator loss: 0.31437837814583497
Generator loss:     4.778341615901274
--- EPOCH 97/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.86it/s]
Discriminator loss: 0.3229763875989353
Generator loss:     5.088379677604227
--- EPOCH 98/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.25061708622995543
Generator loss:     5.63083675328423
--- EPOCH 99/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.92it/s]
Discriminator loss: 0.9091886382769135
Generator loss:     4.938453421873205
--- EPOCH 100/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.99it/s]
Discriminator loss: 0.45473635328166623
Generator loss:     4.318602309507482
In [60]:
# Plot images from best or last model
if os.path.isfile(f'{checkpoint_file}.pt'):
    # Load the best checkpoint if one was saved during training;
    # otherwise `gen` remains the final model from the training loop above.
    gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [1]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?

In [2]:
display_answer(hw3.answers.part3_q1)

We maintain gradients when we sample from the generator inside the generator's part of the batch training function. There, the sampled images are part of the computational graph through which the generator's loss is backpropagated, so the gradients are needed in order to optimize the generator's parameters. On all other occasions we discard the gradients: when sampling fake images for the discriminator's update (there we only want to update the discriminator, not the generator) and when sampling just to display or evaluate results. Discarding them saves memory and computation, and ensures these samples do not change the generator's parameters and ruin the training.
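To make this concrete, below is a minimal sketch of one GAN training batch. The loss functions, optimizers and the discriminator's call signature are hypothetical placeholders, not the exact hw3 implementation; only gen.sample(n=..., with_grad=...) follows the interface used elsewhere in this notebook.

import torch

def train_batch(dsc, gen, dsc_loss_fn, gen_loss_fn, dsc_opt, gen_opt, x_real):
    # Discriminator update: the fake images are treated as fixed inputs,
    # so we sample WITHOUT gradients -- we must not (and need not)
    # backpropagate through the generator here.
    x_fake = gen.sample(n=x_real.shape[0], with_grad=False)
    dsc_opt.zero_grad()
    dsc_loss = dsc_loss_fn(dsc(x_real), dsc(x_fake))
    dsc_loss.backward()
    dsc_opt.step()

    # Generator update: gradients must flow from the discriminator's
    # verdict back through the sampled images into the generator's
    # parameters, so we sample WITH gradients.
    x_fake = gen.sample(n=x_real.shape[0], with_grad=True)
    gen_opt.zero_grad()
    gen_loss = gen_loss_fn(dsc(x_fake))
    gen_loss.backward()
    gen_opt.step()

    return dsc_loss.item(), gen_loss.item()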

Question 2

  1. When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?

  2. What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?

In [3]:
display_answer(hw3.answers.part3_q2)

  1. We can't decide to stop training based solely on the generator loss being below a certain threshold, because the losses of the generator and of the discriminator are coupled: each is measured against the current state of the other network. We can imagine a situation where the generator's loss is very low (so we would think to stop training), but in the next batch the discriminator suddenly improves and finds new differences between the real and fake images, and the generator's loss goes back up.

  2. If the discriminator's loss remains constant while the generator's loss keeps decreasing, we are in a situation where the discriminator can no longer tell the difference between the real and fake images, while the generator keeps making its images better and better in comparison to the real ones.
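As a purely hypothetical illustration (this is not the stopping rule used in this notebook, and the threshold values are arbitrary), a sensible stopping check would have to look at both losses together rather than thresholding the generator loss alone:

def should_stop(dsc_losses, gen_losses, window=10, gen_threshold=0.5, dsc_floor=0.3):
    """Hypothetical heuristic: a low generator loss only counts if the
    discriminator is still providing a meaningful training signal."""
    if len(gen_losses) < window:
        return False
    gen_avg = sum(gen_losses[-window:]) / window
    dsc_avg = sum(dsc_losses[-window:]) / window
    # Stop only if the generator looks good AND the discriminator has
    # not collapsed to a trivially low loss on recent batches.
    return gen_avg < gen_threshold and dsc_avg > dsc_floor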

Question 3

Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?

In [4]:
display_answer(hw3.answers.part3_q3)

The main difference between the results from the VAE and from the GAN is that with the VAE we get a clear face with almost no background but with good facial detail, while with the GAN we get a better overall image in terms of the background and of the face within it, but with fewer fine details. We think this happens because of the way the two models are built, optimized and trained. The VAE tries to extract the most important and recurring features and encode them into a representation in the latent space, which is why the fine details of the face come out well. In comparison, the GAN optimizes the generated image against the discriminator's ability to distinguish real from fake images, so it produces images that as a whole look similar to the dataset, while fine details matter less to its loss.
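To illustrate why the two objectives pull in different directions, here is a hedged sketch of the two loss terms in generic PyTorch; these are standard textbook forms, not the exact implementations in hw3.

import torch
import torch.nn.functional as F

def vae_loss(x, x_rec, mu, log_sigma2):
    # Per-pixel reconstruction: every pixel of the face contributes to
    # the loss, so frequently recurring facial features are matched well.
    rec = F.mse_loss(x_rec, x)
    # KL term pulls the latent posterior towards the standard-normal prior.
    kl = -0.5 * torch.mean(1 + log_sigma2 - mu.pow(2) - log_sigma2.exp())
    return rec + kl

def generator_loss(dsc_scores_on_fake):
    # Only the discriminator's scalar real/fake verdict on the whole
    # image matters, so globally plausible images score well even when
    # the fine details are imperfect.
    return F.binary_cross_entropy_with_logits(
        dsc_scores_on_fake, torch.ones_like(dsc_scores_on_fake))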